Super resolution on medical images:

Set-up:

Imports:
In [12]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import pandas as pd
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
import tensorflow as tf
from torchvision.models import vgg16_bn
import torchvision.transforms as transforms
from skimage import measure

from modified_model_fx import *
In [13]:
import warnings
warnings.filterwarnings('ignore')
In [14]:
trans1 = transforms.ToTensor()
trans = transforms.ToPILImage()

Set CUDA:

In [15]:
torch.cuda.set_device(3)

Import data:

Input images:

  • HR images: (3, 502, 672), from the original (3, 1004, 1344) images cut into 4 pieces
  • LR images: (3, 250, 334), from the original (3, 500, 669) images cut into 4 pieces
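The HR originals are cut into four pieces; a minimal NumPy sketch of this quadrant split (a hypothetical helper, not the notebook's actual preprocessing code):

```python
import numpy as np

# Hypothetical example: cut an image array of shape (C, H, W)
# into its 4 quadrant patches, as done for the HR originals.
def cut_in_four(img):
    c, h, w = img.shape
    h2, w2 = h // 2, w // 2
    return [
        img[:, :h2, :w2], img[:, :h2, w2:],
        img[:, h2:, :w2], img[:, h2:, w2:],
    ]

patches = cut_in_four(np.zeros((3, 1004, 1344)))
# each patch has shape (3, 502, 672), matching the HR patch size above
```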

Paths to data:

In [16]:
path = Path('../../../../../SCRATCH2/marvande/data/train/HR/')

# path to HR data:
path_hr = path / 'HR_patches_train/jpg_images'

# path where we create the LR data:
path_lr = path / 'small-250/train'

# path to MR data (of same size as HR):
path_mr = path / 'small-502/train'

path_mr_test = path / 'small-250/test'

# path to original LR data:
path_lr_or = path/'HR_patches_resized/jpg_images'

assert path.exists(), f"need dataset @ {path}"
assert path_hr.exists()

Have a look at what type of data is in those folders. First, we look at our HR images from path_hr.

In [17]:
il = ImageList.from_folder(path_hr)
ImageList.from_folder(path_hr)
Out[17]:
ImageList (6168 items)
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/HR_patches_train/jpg_images

Create LR and MR data:

From this HR data, we create lower-quality versions in path_lr and path_mr for training. LR images will be of size (3, 250, 334), while MR images are of the same size as the HR images, (3, 502, 672), but of lower quality (MR images are used in the second phase of training).

In [18]:
def resize_one(fn, i, path, size):
    """resize_one: resizes images to input size and saves them in path, 
    quality is lowered as to get LR images. 
    
    """
    dest = path / fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn)
    img = img.resize(size, resample=PIL.Image.BILINEAR).convert('RGB')
    img.save(dest, quality=60)
In [19]:
# create smaller image sets the first time this nb is run:
sets = [(path_lr, (334, 250)), (path_mr, (672, 502))]
for p, size in sets:
    if not p.exists():
        print(f"resizing to {size} into {p}")
        parallel(partial(resize_one, path=p, size=size), il.items)

Have a look at the new LR and MR images and their sizes.

In [20]:
print('LR:')
print(ImageList.from_folder(path_lr))
print('MR:')
print(ImageList.from_folder(path_mr))
LR:
ImageList (2056 items)
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train
MR:
ImageList (2056 items)
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-502/train

Training and validation data:

As the model's first training phase runs on images of the LR size, we create the first batch of training data at size (3, 250, 334). HR images are therefore downsized to (3, 250, 334), and both HR and LR images are transformed with several transformations (see modified_model_fx.py). These data augmentation techniques help avoid overfitting during training.
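The transforms themselves live in modified_model_fx.py and are not shown here; the sketch below only illustrates the key constraint for super resolution, namely that each random augmentation must be applied identically to the input and its target (a hypothetical helper, not the notebook's code):

```python
import numpy as np

# Sketch: any random augmentation must transform the LR input and its
# HR target in exactly the same way, or the pair becomes misaligned.
rng = np.random.default_rng(42)

def paired_flip(x, y):
    """Randomly flip a (C, H, W) input/target pair the same way."""
    if rng.random() < 0.5:          # horizontal flip
        x, y = x[:, :, ::-1], y[:, :, ::-1]
    if rng.random() < 0.5:          # vertical flip
        x, y = x[:, ::-1, :], y[:, ::-1, :]
    return x, y

x, y = paired_flip(np.arange(6).reshape(1, 2, 3),
                   np.arange(6).reshape(1, 2, 3))
assert (x == y).all()  # identical augmentation keeps the pair aligned
```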

In [25]:
# set image size and batch size to which data is transformed:
bs, size = 15, (250, 334)
arch = models.resnet34

src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)
src
Out[25]:
ItemLists;

Train: ImageImageList (1851 items)
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Valid: ImageImageList (205 items)
Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334),Image (3, 250, 334)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Test: None

Define get_data, which builds the databunch and normalises both inputs and targets with the ImageNet statistics (imagenet_stats).

In [22]:
def get_data(bs, size):
    """
    get_data: creates training and validation data from LR and HR. 
    downsizes HR to LR size and applies transformations to both. 
    """

    #label_from_func: apply func to every input to get its label.
    # defining a custom function to extract the labels
    data = src.label_from_func(lambda x: path_hr / x.relative_to(path_lr))

    #apply data transformations,
    data = data.transform(tfms, size=size, tfm_y=True).databunch(bs=bs).normalize(imagenet_stats,do_y=True)
    data.c = 3
    return data
In [24]:
data = get_data(bs, size)
data
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-24-94c1f03ed0be> in <module>
----> 1 data = get_data(bs, size)
      2 data

<ipython-input-22-5dd5e9c3fa8a> in get_data(bs, size)
     10 
     11     #apply data transformations,
---> 12     data = data.transform(tfms, size=size, tfm_y=True).databunch(bs=bs).normalize(imagenet_stats,do_y=True)
     13     data.c = 3
     14     return data

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/data_block.py in transform(self, tfms, **kwargs)
    503         if not tfms: tfms=(None,None)
    504         assert is_listy(tfms) and len(tfms) == 2, "Please pass a list of two lists of transforms (train and valid)."
--> 505         self.train.transform(tfms[0], **kwargs)
    506         self.valid.transform(tfms[1], **kwargs)
    507         if self.test: self.test.transform(tfms[1], **kwargs)

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/data_block.py in transform(self, tfms, tfm_y, **kwargs)
    725         if tfm_y is None: tfm_y = self.tfm_y
    726         tfms_y = None if tfms is None else list(filter(lambda t: getattr(t, 'use_on_y', True), listify(tfms)))
--> 727         if tfm_y: _check_kwargs(self.y, tfms_y, **kwargs)
    728         self.tfms,self.tfmargs  = tfms,kwargs
    729         self.tfm_y,self.tfms_y,self.tfmargs_y = tfm_y,tfms_y,kwargs

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/data_block.py in _check_kwargs(ds, tfms, **kwargs)
    591     if (tfms is None or len(tfms) == 0) and len(kwargs) == 0: return
    592     if len(ds.items) >= 1:
--> 593         x = ds[0]
    594         try: x.apply_tfms(tfms, **kwargs)
    595         except Exception as e:

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/data_block.py in __getitem__(self, idxs)
    118         "returns a single item based if `idxs` is an integer or a new `ItemList` object if `idxs` is a range."
    119         idxs = try_int(idxs)
--> 120         if isinstance(idxs, Integral): return self.get(idxs)
    121         else: return self.new(self.items[idxs], inner_df=index_row(self.inner_df, idxs))
    122 

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/vision/data.py in get(self, i)
    269     def get(self, i):
    270         fn = super().get(i)
--> 271         res = self.open(fn)
    272         self.sizes[i] = res.size
    273         return res

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/vision/data.py in open(self, fn)
    265     def open(self, fn):
    266         "Open image in `fn`, subclass and overwrite for custom behavior."
--> 267         return open_image(fn, convert_mode=self.convert_mode, after_open=self.after_open)
    268 
    269     def get(self, i):

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/vision/image.py in open_image(fn, div, convert_mode, cls, after_open)
    396     with warnings.catch_warnings():
    397         warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
--> 398         x = PIL.Image.open(fn).convert(convert_mode)
    399     if after_open: x = after_open(x)
    400     x = pil2tensor(x,np.float32)

~/image_super_resolution/image_res/lib/python3.7/site-packages/PIL/Image.py in open(fp, mode)
   2841 
   2842     if filename:
-> 2843         fp = builtins.open(filename, "rb")
   2844         exclusive_fp = True
   2845 

FileNotFoundError: [Errno 2] No such file or directory: '/SCRATCH2/marvande/data/train/HR/HR_patches_train/jpg_images/0124_[47360,11368]_part_1_1_.tif'

Have a look at a few samples from the validation data. On the left are the LR images, on the right the corresponding HR images.

In [ ]:
# lr 
x = data.one_batch(ds_type=DatasetType.Valid)[0][4]
# hr
y = data.one_batch(ds_type=DatasetType.Valid)[1][4]
print('LR data shape {} and HR shape {}'.format(list(x.shape), list(y.shape)))

MSE = mse_mult_chann(x.numpy(), y.numpy())
NMSE = measure.compare_nrmse(x.numpy(), y.numpy())
SSIM = ssim_mult_chann(x.numpy(), y.numpy())
print('MSE: {}, NMSE: {}, SSIM: {}'.format(MSE, NMSE, SSIM))

plot_single_image(x, '', (10,10))
plot_single_image(y, '', (10,10))
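mse_mult_chann and ssim_mult_chann come from modified_model_fx.py and are not shown here; assuming the former simply averages the squared error over all channels, it can be sketched as:

```python
import numpy as np

# Sketch of a per-channel MSE (hypothetical stand-in for mse_mult_chann
# from modified_model_fx.py, which is not shown in this notebook).
def mse_mult_chann_sketch(x, y):
    """Mean squared error over all channels of two (C, H, W) arrays."""
    return float(np.mean((x - y) ** 2))

x = np.zeros((3, 4, 4))
y = np.full((3, 4, 4), 0.1)
print(mse_mult_chann_sketch(x, y))  # ~0.01 (constant 0.1 offset, squared)
```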
In [26]:
! cp ../../../../../SCRATCH2/marvande/data/train/HR/small-502/train/models/2a_mod_norm.pth ../../../../../SCRATCH2/marvande/data/train/HR/models/
cp: cannot stat '../../../../../SCRATCH2/marvande/data/train/HR/small-502/train/models/2a_mod_norm.pth': No such file or directory
In [13]:
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(20, 20))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-cce01b1e9089> in <module>
----> 1 data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(20, 20))

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/basic_data.py in show_batch(self, rows, ds_type, reverse, **kwargs)
    184     def show_batch(self, rows:int=5, ds_type:DatasetType=DatasetType.Train, reverse:bool=False, **kwargs)->None:
    185         "Show a batch of data in `ds_type` on a few `rows`."
--> 186         x,y = self.one_batch(ds_type, True, True)
    187         if reverse: x,y = x.flip(0),y.flip(0)
    188         n_items = rows **2 if self.train_ds.x._square_show else rows

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/basic_data.py in one_batch(self, ds_type, detach, denorm, cpu)
    167         w = dl.num_workers
    168         dl.num_workers = 0
--> 169         try:     x,y = next(iter(dl))
    170         finally: dl.num_workers = w
    171         if detach: x,y = to_detach(x,cpu=cpu),to_detach(y,cpu=cpu)

~/image_super_resolution/image_res/lib/python3.7/site-packages/fastai/basic_data.py in __iter__(self)
     73     def __iter__(self):
     74         "Process and returns items from `DataLoader`."
---> 75         for b in self.dl: yield self.proc_batch(b)
     76 
     77     @classmethod

~/image_super_resolution/image_res/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

~/image_super_resolution/image_res/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
--> 387             data = _utils.pin_memory.pin_memory(data)
    388         return data
    389 

~/image_super_resolution/image_res/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/image_super_resolution/image_res/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py in <listcomp>(.0)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/image_super_resolution/image_res/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     45 def pin_memory(data):
     46     if isinstance(data, torch.Tensor):
---> 47         return data.pin_memory()
     48     elif isinstance(data, string_classes):
     49         return data

RuntimeError: cuda runtime error (46) : all CUDA-capable devices are busy or unavailable at /pytorch/aten/src/THC/THCCachingHostAllocator.cpp:278

Feature loss:

Create loss metrics used in the model.

In [62]:
def gram_matrix(x):
    """
    Gram matrix of a set of vectors in an inner 
    product space is the Hermitian matrix of inner products
    """
    n,c,h,w = x.size()
    x = x.view(n, c, -1)
    return (x @ x.transpose(1,2))/(c*h*w)

Select the data of an image and use it to compute its Gram matrix.

In [63]:
im = data.valid_ds[0][1]
t = im.data
t = torch.stack([t,t])
In [64]:
gram_matrix(t)
Out[64]:
tensor([[[0.3176, 0.3176, 0.3176],
         [0.3176, 0.3176, 0.3176],
         [0.3176, 0.3176, 0.3176]],

        [[0.3176, 0.3176, 0.3176],
         [0.3176, 0.3176, 0.3176],
         [0.3176, 0.3176, 0.3176]]])
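As a sanity check, the same Gram computation can be reproduced in NumPy (a sketch, not the notebook's code); for an input of shape (n, c, h, w) the result has shape (n, c, c) and is symmetric:

```python
import numpy as np

# NumPy sketch of gram_matrix above: for x of shape (n, c, h, w),
# flatten the spatial dims and take (x @ x^T) / (c*h*w) per batch item.
def gram_matrix_np(x):
    n, c, h, w = x.shape
    x = x.reshape(n, c, h * w)
    return x @ x.transpose(0, 2, 1) / (c * h * w)

g = gram_matrix_np(np.random.rand(2, 3, 4, 5))
assert g.shape == (2, 3, 3)
assert np.allclose(g, g.transpose(0, 2, 1))  # Gram matrices are symmetric
```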

We define a base loss as the L1 loss (F.l1_loss).
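Numerically, F.l1_loss is the mean absolute error between prediction and target; a tiny NumPy illustration:

```python
import numpy as np

# L1 loss = mean absolute error between prediction and target.
pred   = np.array([0.0, 0.5, 1.0])
target = np.array([0.0, 1.0, 1.0])
l1 = np.mean(np.abs(pred - target))
print(l1)  # ~0.1667 (= 0.5 / 3)
```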

In [65]:
base_loss = F.l1_loss

Construct a pre-trained VGG-16 model (Very Deep Convolutional Networks for Large-Scale Image Recognition). VGG is a Convolutional Neural Network (CNN) architecture devised in 2014; the 16-layer version is used in the loss function for training this model. The VGG model, a network pretrained on ImageNet, is used to evaluate the generator model's loss.

Further, we set requires_grad to False to freeze the VGG parameters: this is useful when you want to freeze part of your model, or you know in advance that you will not need gradients w.r.t. some parameters.

The head of the VGG model consists of the final fully connected and softmax layers. This head is ignored, and the loss function instead uses the intermediate activations in the backbone of the network, which represent the feature detections.

Those activations can be found by looking through the VGG model for the max pooling layers: these are where the grid size changes and features have just been detected, so we select the layers immediately before them.

In [66]:
vgg_m = vgg16_bn(True).features.cuda().eval()
requires_grad(vgg_m, False)

Select the layer IDs just before each MaxPool2d block (the final ReLU of each stage):

In [67]:
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks]
Out[67]:
([5, 12, 22, 32, 42],
 [ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True)])

Create the feature loss from the model and layer ids selected above.

Main points:

  • Loss functions using these techniques can be used during the training of U-Net based model architectures
  • Feature loss: the loss function used is similar to the one in the paper, using VGG-16 but also combined with a pixel mean squared error loss and a Gram matrix style loss
  • The training of a model can use this loss function based on the VGG model’s activations. The loss function remains fixed throughout training, unlike the critic part of a GAN
  • Feature loss: a feature map has, for example, 256 channels by 28 by 28. The activations at the same layer for the (target) original image and the generated image are compared using the mean squared error or the least absolute error (L1) as the base loss; these are the feature losses. This implementation uses the L1 error. This lets the loss function know which features are present in the target ground-truth image and evaluate how well the model’s predicted features match them, rather than only comparing pixel differences. A model trained with this loss function therefore produces much finer detail in its generated/predicted features and output.
  • Gram matrix style loss: a Gram matrix defines a style with respect to specific content. Calculating the Gram matrix for each feature activation in the target/ground-truth image defines the style of that feature. If the same Gram matrix is calculated from the activations of the predictions, the two can be compared to measure how close the style of the predicted feature is to that of the target/ground truth. A Gram matrix is the matrix multiplication of each activation matrix with its transpose. This enables the model to learn and generate predictions whose features look correct in style and in context, with the end result looking more convincing and closer to the target/ground truth.

Predictions from models trained with this loss function: the generated predictions have both convincing fine detail and style. Style and fine detail cover different aspects of image quality, whether that is predicting fine pixel detail or predicting the correct colours.

In [68]:
class FeatureLoss(nn.Module):
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        #feat losses
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        #gram: 
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self): self.hooks.remove()
In [69]:
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])

Train

In [73]:
lr = 1e-3
def do_fit(save_name, lrs=slice(lr), pct_start=0.9):
    learn.fit_one_cycle(10, lrs, pct_start=pct_start)
    learn.save(save_name)
    learn.show_results(rows=1, imgsize=10)
In [72]:
# delete all tensors and free cache:
for obj in gc.get_objects():
    if torch.is_tensor(obj):
        del obj
torch.cuda.empty_cache()
gc.collect()
#learn.destroy()

#get free memory (in MBs) for the currently selected gpu id, after emptying the cache
print(
    'free memory (in MBs) for the currently selected gpu id, after emptying the cache: ',
    gpu_mem_get_free_no_cache())

print(
    'used memory (in MBs) for the currently selected gpu id, after emptying the cache:',
    gpu_mem_get_used_no_cache())

gpu_mem_get_all()
this Learner object self-destroyed - it still exists, but no longer usable
free memory (in MBs) for the currently selected gpu id, after emptying the cache:  7254
used memory (in MBs) for the currently selected gpu id, after emptying the cache: 4958
Out[72]:
[GPUMemory(total=12212, free=8334, used=3878),
 GPUMemory(total=12196, free=10448, used=1747),
 GPUMemory(total=12196, free=7219, used=4976),
 GPUMemory(total=12212, free=7254, used=4958)]
In [74]:
wd = 1e-3
learn = unet_learner(data,
                     arch,
                     wd=wd,
                     loss_func=feat_loss,
                     callback_fns=LossMetrics,
                     blur=True,
                     norm_type=NormType.Weight)
# garbage collection:
gc.collect()
Out[74]:
9
In [75]:
learn.lr_find()
learn.recorder.plot()
0.00% [0/1 00:00<00:00]
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time

25.68% [95/370 01:55<05:34 2.6148]
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [55]:
do_fit('1a_mod_norm', slice(lr*10))
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.745564 0.703226 0.072518 0.133972 0.127637 0.024699 0.175697 0.157595 0.011107 07:51
1 0.696248 0.682734 0.069820 0.130308 0.124251 0.024019 0.177779 0.145463 0.011095 07:51
2 0.706219 0.683751 0.088931 0.128138 0.122131 0.023521 0.167327 0.142927 0.010776 07:51
3 0.702797 0.647316 0.073275 0.126877 0.121393 0.023076 0.150515 0.141607 0.010573 07:51
4 0.677807 0.646809 0.072707 0.128532 0.121571 0.023378 0.154164 0.135981 0.010476 07:51
5 0.672722 0.663200 0.089357 0.127985 0.121119 0.023242 0.156748 0.134380 0.010368 07:51
6 0.653321 0.634007 0.074245 0.126788 0.119412 0.023058 0.147423 0.132760 0.010322 07:51
7 0.652571 0.649495 0.069889 0.129649 0.123050 0.023757 0.152688 0.139833 0.010630 07:50
8 0.649285 0.625629 0.075241 0.127799 0.120113 0.022935 0.139549 0.129826 0.010166 07:49
9 0.605637 0.587562 0.066955 0.125243 0.117836 0.022505 0.121519 0.123578 0.009927 07:49
In [57]:
learn.unfreeze()
In [58]:
do_fit('1b_mod_norm', slice(1e-5,lr))
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.609209 0.588111 0.067097 0.125175 0.117747 0.022468 0.122010 0.123666 0.009947 08:06
1 0.606251 0.588426 0.067203 0.125117 0.117754 0.022474 0.122239 0.123696 0.009942 08:06
2 0.606750 0.583781 0.066779 0.125026 0.117596 0.022413 0.119552 0.122519 0.009896 08:06
3 0.605110 0.582180 0.066731 0.124527 0.117133 0.022378 0.119887 0.121650 0.009873 08:06
4 0.600497 0.582473 0.066775 0.124738 0.117312 0.022364 0.119504 0.121907 0.009872 08:07
5 0.599919 0.583216 0.067232 0.124702 0.117376 0.022431 0.119637 0.121964 0.009874 08:04
6 0.603773 0.582166 0.066981 0.123890 0.116472 0.022203 0.120088 0.122710 0.009822 08:05
7 0.599580 0.576537 0.066625 0.123798 0.116562 0.022212 0.116819 0.120752 0.009769 08:06
8 0.592966 0.580542 0.067110 0.124194 0.116774 0.022324 0.118824 0.121478 0.009838 08:06
9 0.594584 0.571241 0.066267 0.123617 0.116246 0.022171 0.113641 0.119561 0.009738 08:06

Phase 2:

In [27]:
torch.cuda.empty_cache()
gc.collect()
#learn.destroy()

#get free memory (in MBs) for the currently selected gpu id, after emptying the cache
print(
    'free memory (in MBs) for the currently selected gpu id, after emptying the cache: ',
    gpu_mem_get_free_no_cache())

print(
    'used memory (in MBs) for the currently selected gpu id, after emptying the cache:',
    gpu_mem_get_used_no_cache())

gpu_mem_get_all()
free memory (in MBs) for the currently selected gpu id, after emptying the cache:  10880
used memory (in MBs) for the currently selected gpu id, after emptying the cache: 1332
Out[27]:
[GPUMemory(total=12212, free=9840, used=2372),
 GPUMemory(total=12196, free=7717, used=4478),
 GPUMemory(total=12196, free=10534, used=1661),
 GPUMemory(total=12212, free=10880, used=1332)]
In [28]:
new_size = (502, 672)
bs = 4
data = get_data(bs,new_size)

wd = 1e-3
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics,
                     blur=True, norm_type=NormType.Weight)

lr = 1e-3
# garbage collection: 
gc.collect();
In [29]:
learn.data = data
learn.freeze()
gc.collect()
Out[29]:
20
In [30]:
data
Out[30]:
ImageDataBunch;

Train: LabelList (5552 items)
x: ImageImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
y: ImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Valid: LabelList (616 items)
x: ImageImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
y: ImageList
Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672),Image (3, 502, 672)
Path: ../../../../../SCRATCH2/marvande/data/train/HR/small-250/train;

Test: None
In [31]:
learn.load('1b_mod_norm');
In [32]:
do_fit('2a_mod_norm')
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.832925 0.896616 0.091470 0.182345 0.171421 0.027638 0.255280 0.158550 0.009911 38:28
1 0.815595 0.874326 0.091203 0.179802 0.168967 0.027145 0.244489 0.153000 0.009721 38:17
2 0.798485 0.847727 0.091752 0.179523 0.168102 0.027114 0.224606 0.147073 0.009558 38:22
3 0.781896 0.847235 0.091324 0.177996 0.166898 0.026825 0.225327 0.149394 0.009472 38:24
4 0.784576 0.820942 0.092667 0.177002 0.166586 0.026806 0.209845 0.138598 0.009438 38:25
5 0.773251 0.809337 0.091778 0.176275 0.165426 0.026528 0.202888 0.137133 0.009309 38:23
6 0.764628 0.780511 0.093099 0.174661 0.163290 0.026499 0.184308 0.129516 0.009138 38:23
7 0.763625 0.766990 0.093637 0.173792 0.162081 0.026341 0.176648 0.125478 0.009013 38:23
8 0.734789 0.806548 0.091125 0.174860 0.164566 0.026842 0.202648 0.137315 0.009193 38:23
9 0.710625 0.746289 0.093340 0.171752 0.160779 0.026076 0.165443 0.119996 0.008903 38:24
In [33]:
learn.unfreeze()
In [34]:
do_fit('2b_mod_norm', slice(1e-6,1e-4), pct_start=0.3)
epoch train_loss valid_loss pixel feat_0 feat_1 feat_2 gram_0 gram_1 gram_2 time
0 0.701475 0.741437 0.093549 0.171642 0.160506 0.025934 0.161429 0.119469 0.008907 39:25
1 0.708152 0.740292 0.093802 0.172408 0.160315 0.025959 0.158914 0.120012 0.008881 39:27
2 0.704648 0.740365 0.093259 0.171283 0.160305 0.025952 0.161740 0.118937 0.008890 39:27
3 0.718388 0.731713 0.094452 0.171976 0.160266 0.025964 0.154668 0.115631 0.008755 39:26
4 0.710667 0.733564 0.093322 0.171518 0.160193 0.025935 0.156675 0.117153 0.008769 39:27
5 0.689970 0.729099 0.093765 0.171401 0.159826 0.025818 0.153467 0.116054 0.008768 39:28
6 0.694874 0.725248 0.093590 0.171077 0.159403 0.025804 0.151915 0.114694 0.008764 39:29
7 0.684416 0.730636 0.092919 0.170981 0.159667 0.025848 0.156883 0.115608 0.008730 39:30
8 0.703657 0.723982 0.093868 0.170980 0.159456 0.025799 0.150847 0.114312 0.008720 39:26
9 0.694823 0.721953 0.093497 0.170710 0.159134 0.025752 0.150347 0.113813 0.008700 39:23

Testing:

Reconstruct the learn object, this time with l1_loss instead of feature_loss (why?).

In [23]:
learn = None
gc.collect()
learn = unet_learner(data,
                     arch,
                     loss_func=F.l1_loss,
                     blur=True,
                     norm_type=NormType.Weight)

Prediction:

Set the sizes of the LR input images (size_lr) for testing and the HR images (size_mr).

In [24]:
size_mr = (3, 502, 672)
size_lr = (3, 250, 334)

path_test = path/'HR_patches_test/cut_images/jpeg_images/'
path_mr_test = path / 'small-250/test'


#Check free GPU RAM:
free = gpu_mem_get_free_no_cache()
print(f"using size={size}, have {free}MB of GPU RAM free")
using size=(250, 334), have 11505MB of GPU RAM free
In [25]:
# Create test images:
test_files = []

with os.scandir(path_test) as entries:
        test_files = [entry.name for entry in entries if entry.is_file()]

size = (334, 250)
for file in test_files:
    if not (path_mr_test / file).exists():  # only create missing LR test images
        img = PIL.Image.open(path_test/file)
        img = img.resize(size, resample=PIL.Image.BILINEAR).convert('RGB')
        img.save(path_mr_test/file, quality=60)

Create testing data, data_mr of size HR and data_lr of size LR:

In [26]:
data_mr = (ImageImageList.from_folder(path_test).split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_test/x.name)
          .transform(tfms, size=size_mr, tfm_y=True)
          .databunch(bs=1).normalize(imagenet_stats, do_y=True))
data_mr.c = 3
data_lr = (ImageImageList.from_folder(path_mr_test).split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_test/x.name)
          .transform(tfms, size=size_lr, tfm_y=True)
          .databunch(bs=1).normalize(imagenet_stats, do_y=True))
data_lr.c = 3

Load the learn object from the last phase (2b_mod_norm) and set its data to data_mr (why?).

In [27]:
learn.load('2b_mod_norm')
learn.data = data_mr

Select the images we are going to use for testing from data_mr and data_lr:

In [28]:
# Ground truth HR image:
gt_HR = data_mr.valid_ds.y.items[1]

# LR version of the same image:
lr = data_lr.valid_ds.x.items[1]

To be able to apply the model and predict the HR image from the LR, we need to resample the LR image to be of the same size as the HR, as the model inputs and outputs images of the same size.

In [29]:
lr_data = open_image(lr)

# resample to same size as HR:
# for pytorch, have to add a first new dimension to indicate bs = 1
# now lr_data of shape [1, 3, 250, 334]
lr_data = lr_data.data.unsqueeze(0)
lr_resized = torch.nn.functional.interpolate(lr_data, size_mr[1:],
                                             mode='bicubic')
# remove the previous added dimension
lr_resized = lr_resized.squeeze()

# plot resized LR: 
plot_single_image(lr_resized, 
                  'Resized LR from size {} to size {}'\
                  .format(list(lr_data.squeeze().shape), 
                          list(lr_resized.shape)), (10,10))

Now that we have resized the LR to the same size as the ground truth HR, we can feed it to the model and predict a new HR image.

In [30]:
# Ground truth HR: 
im_HR_gt = open_image(gt_HR)
#print('HR ground thruth shape: {}'.format(list(im_HR_gt.shape)))

# Prediction of model: 
p, img_pred, b = learn.predict(Image(lr_resized))
#print('Reconstructed HR shape: {}'.format(list(p.shape)))

# Assert reconstructed HR has same shape as ground truth HR:
assert(list(p.shape) == list(im_HR_gt.shape))

Visualisation:

Compare this predicted image to the ground truth HR and LR:

In [31]:
plot_3_images(lr_resized, trans1(trans(im_HR_gt.data)).numpy(), img_pred, (20,20))
In [43]:
print('PRed')
print(np.max(img_pred.numpy()))
print(np.min(img_pred.numpy()))

print('Ground truth HR')
print(np.max(trans1(trans(im_HR_gt.data)).numpy()))
print(np.min(trans1(trans(im_HR_gt.data)).numpy()))

print('LR bic')
print(np.max(lr_resized.numpy()))
print(np.min(lr_resized.numpy()))
PRed
1.0705991
0.080254346
Ground truth HR
1.0
0.13333334
LR bic
1.0230343
0.12708423

Clip the prediction to the ground-truth HR range [min_HR, max_HR] and check that the MSE barely changes:

In [47]:
max_HR = np.max(trans1(trans(im_HR_gt.data)).numpy())
min_HR = np.min(trans1(trans(im_HR_gt.data)).numpy())

test_im = np.clip(a_max = max_HR,a_min = min_HR, a = img_pred.numpy())

print(mse_mult_chann(img_pred.numpy(), trans1(trans(im_HR_gt.data)).numpy()))
mse_mult_chann(test_im, trans1(trans(im_HR_gt.data)).numpy())
0.0027782882029333002
Out[47]:
0.0027779219382844107
Evaluate loss and error metrics:
In [39]:
compare_images_metrics(img_pred.numpy(),trans1(trans(im_HR_gt.data)).numpy(),lr_resized.numpy(),'')
Predicted HR (left) VS original HR (right)
MSE: 0.00277829, NMSE: 0.06309721, SSIM : 0.8016

In [34]:
p = '../../../../../SCRATCH2/marvande/data/train/HR/small-250/test/0131_[46833,16191]_part_1_-3.jpg'
pattern = "test\/(.*?)_"
substring = re.search(pattern, p).group(1)
substring
Out[34]:
'0131'
In [35]:
## Plot for several: 
LR_list = ImageList.from_folder(path_mr_test).items
HR_list = ImageList.from_folder(path_test).items
phenotypes = {'0': 'CD4', '1': 'CK','2': 'DAPI', '3': 'CD3', '4': 'FoxP3', '5': 'CD8'}

for i in range(len(LR_list[0:5])):
    lr_data = open_image(LR_list[i])
    
    # get file string
    pattern = "test\/(.*?)\.jpg"
    substring = re.search(pattern, str(LR_list[i])).group(1)
    file_name = substring+'.jpg'
    
    # get patient number:
    pattern = "test\/(.*?)_"
    patient = re.search(pattern, str(LR_list[i])).group(1)
    
    # get phenotype number:
    pattern2 = "_\-(.*?)\.jpg"
    phenotype_numb = re.search(pattern2,file_name).group(1)
    phenotype = phenotypes[phenotype_numb]
    
    # get location number:
    pattern2 = r"\[(.*?)\]"
    location = re.search(pattern2,file_name).group(1)        
        
    # resample to same size as HR:
    # for pytorch, have to add a first new dimension to indicate bs = 1
    # now lr_data of shape [1, 3, 250, 334]
    lr_data = lr_data.data.unsqueeze(0)
    lr_resized = torch.nn.functional.interpolate(lr_data, size_mr[1:],
                                                     mode='bicubic')
    # remove the previous added dimension
    lr_resized = lr_resized.squeeze()

    # corresponding ground truth HR: 
    im_HR_gt = open_image(HR_list[i])
    #print('HR ground truth shape: {}'.format(list(im_HR_gt.shape)))

    # Prediction of model: 
    p, img_pred, b = learn.predict(Image(lr_resized))
    #print('Reconstructed HR shape: {}'.format(list(p.shape)))
    # Assert reconstructed HR has same shape as ground truth HR:
    assert(list(p.shape) == list(im_HR_gt.shape))

    gt_HR_nd = trans1(trans(im_HR_gt.data)).numpy()
    pred_HR_nd = img_pred.numpy()
    
    compare_images_metrics(
        pred_HR_nd, gt_HR_nd,
        'phenotype: {}, patient: {}, location: [{}]'.format(phenotype, patient, location))
Predicted HR (left) VS original HR (right)
MSE: 0.00058833, NMSE: 0.02852871, SSIM : 0.9320
phenotype: CD3, patient: 0131, location: [46833,16191]
Predicted HR (left) VS original HR (right)
MSE: 0.00065341, NMSE: 0.02886579, SSIM : 0.9396
phenotype: CK, patient: 0131, location: [42912,16664]
Predicted HR (left) VS original HR (right)
MSE: 0.00019014, NMSE: 0.01438317, SSIM : 0.9760
phenotype: CK, patient: 0131, location: [42242,13869]
Predicted HR (left) VS original HR (right)
MSE: 0.00033488, NMSE: 0.01939397, SSIM : 0.9697
phenotype: CD8, patient: 0131, location: [41114,18347]
Predicted HR (left) VS original HR (right)
MSE: 0.00277829, NMSE: 0.06309721, SSIM : 0.8016
phenotype: DAPI, patient: 0131, location: [45930,8132]
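The resampling step inside the loop (unsqueeze to add a batch dimension, bicubic interpolate, squeeze back) can be exercised on its own; size_mr[1:] is assumed here to be (502, 672), matching the MR/HR image size:

```python
import torch

lr = torch.rand(3, 250, 334)        # one LR image tensor
lr_b = lr.unsqueeze(0)              # [1, 3, 250, 334]: batch of size 1
up = torch.nn.functional.interpolate(
    lr_b, size=(502, 672), mode='bicubic', align_corners=False)
lr_resized = up.squeeze(0)          # drop the batch dim: [3, 502, 672]
print(lr_resized.shape)
```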

Loss and error metrics:

In [38]:
# Create test table with results:

def testing_images(path_lr, path_hr, show_im = False):
    phenotypes = {'0': 'CD4', '1': 'CK','2': 'DAPI', '3': 'CD3', '4': 'FoxP3', '5': 'CD8'}
    MSE,NMSE, SSIM = [], [],[]
    LR_list = ImageList.from_folder(path_lr).items
    HR_list = ImageList.from_folder(path_hr).items
    file_names, locations, patients, phenotypes_list = [], [], [], []
    LR_MSE, LR_NMSE, LR_SSIM= [], [], []
    
    for i in range(min(20, len(LR_list))):
        lr_data = open_image(LR_list[i])
        
        # get file string:
        pattern = r"test/(.*?)\.jpg"
        substring = re.search(pattern, str(LR_list[i])).group(1)
        file_name = substring + '.jpg'
        file_names.append(file_name)
        
        # get patient number:
        pattern = r"test/(.*?)_"
        patients.append(re.search(pattern, str(LR_list[i])).group(1))

        # get phenotype number:
        pattern2 = r"_-(.*?)\.jpg"
        phenotype_numb = re.search(pattern2,file_name).group(1)
        phenotypes_list.append(phenotypes[phenotype_numb])

        # get location number:
        pattern3 = r"\[(.*?)\]"
        locations.append('[{}]'.format(re.search(pattern3, file_name).group(1)))

        # resample to same size as HR:
        # for pytorch, have to add a first new dimension to indicate bs = 1
        # now lr_data of shape [1, 3, 250, 334]
        lr_data = lr_data.data.unsqueeze(0)
        lr_resized = torch.nn.functional.interpolate(lr_data, size_mr[1:],
                                                     mode='bicubic')
        # remove the previous added dimension
        lr_resized = lr_resized.squeeze()

        # corresponding ground truth HR: 
        im_HR_gt = open_image(HR_list[i])
        #print('HR ground truth shape: {}'.format(list(im_HR_gt.shape)))

        # Prediction of model: 
        p, img_pred, b = learn.predict(Image(lr_resized))
        #print('Reconstructed HR shape: {}'.format(list(p.shape)))

        # Assert reconstructed HR has same shape as ground truth HR:
        assert(list(p.shape) == list(im_HR_gt.shape))

        gt_HR_nd = trans1(trans(im_HR_gt.data)).numpy()
        pred_HR_nd = img_pred.numpy()

        MSE.append(mse_mult_chann(gt_HR_nd, pred_HR_nd))
        NMSE.append(measure.compare_nrmse(gt_HR_nd, pred_HR_nd))
        SSIM.append(ssim_mult_chann(gt_HR_nd, pred_HR_nd))
        
        LR_MSE.append(mse_mult_chann(lr_resized.numpy(), pred_HR_nd))
        LR_NMSE.append(measure.compare_nrmse(lr_resized.numpy(), pred_HR_nd))
        LR_SSIM.append(ssim_mult_chann(lr_resized.numpy(), pred_HR_nd))

    return pd.DataFrame(data = {'file':file_names,'patient':patients, 
                                'location':locations, 
                                'phenotype':phenotypes_list,
                                'MSE':MSE, 
                                'NMSE':NMSE, 'SSIM':SSIM, 
                               'LR_MSE': LR_MSE, 
                               'LR_NMSE':LR_NMSE,
                               'LR_SSIM':LR_SSIM})
      
testing_images(path_mr_test, path_test)
Out[38]:
file patient location phenotype MSE NMSE SSIM LR_MSE LR_NMSE LR_SSIM
0 0131_[46833,16191]_part_1_-3.jpg 0131 [46833,16191] CD3 0.000588 0.028529 0.932013 0.000239 0.018185 0.974292
1 0131_[42912,16664]_part_2_-1.jpg 0131 [42912,16664] CK 0.000653 0.028866 0.939602 0.000325 0.020380 0.972099
2 0131_[42242,13869]_part_1_-1.jpg 0131 [42242,13869] CK 0.000190 0.014383 0.975974 0.000128 0.011796 0.984911
3 0131_[41114,18347]_part_2_-5.jpg 0131 [41114,18347] CD8 0.000335 0.019394 0.969690 0.000224 0.015862 0.981349
4 0131_[45930,8132]_part_4_-2.jpg 0131 [45930,8132] DAPI 0.002778 0.063097 0.801640 0.000614 0.029738 0.955995
5 0131_[49577,17995]_part_2_-5.jpg 0131 [49577,17995] CD8 0.001919 0.050938 0.807614 0.000434 0.024246 0.956979
6 0131_[40968,17689]_part_3_-0.jpg 0131 [40968,17689] CD4 0.000227 0.015651 0.980994 0.000198 0.014615 0.983728
7 0131_[47748,18478]_part_4_-5.jpg 0131 [47748,18478] CD8 0.000335 0.019255 0.967275 0.000174 0.013884 0.984839
8 0131_[43076,8951]_part_2_-3.jpg 0131 [43076,8951] CD3 0.001084 0.035522 0.850774 0.000321 0.019350 0.957733
9 0131_[50072,7868]_part_3_-5.jpg 0131 [50072,7868] CD8 0.000733 0.028273 0.902026 0.000292 0.017852 0.964223
10 0131_[44467,14410]_part_1_-4.jpg 0131 [44467,14410] FoxP3 0.000477 0.023085 0.938726 0.000175 0.013973 0.978849
11 0131_[42242,13869]_part_2_-3.jpg 0131 [42242,13869] CD3 0.000331 0.019296 0.971792 0.000179 0.014168 0.985651
12 0131_[48485,17537]_part_3_-2.jpg 0131 [48485,17537] DAPI 0.002870 0.060582 0.742069 0.000531 0.026108 0.944792
13 0131_[43515,10518]_part_4_-0.jpg 0131 [43515,10518] CD4 0.003640 0.065416 0.773562 0.000523 0.024851 0.953420
14 0131_[40237,15610]_part_1_-5.jpg 0131 [40237,15610] CD8 0.000251 0.016656 0.971835 0.000165 0.013509 0.982252
15 0131_[40734,12991]_part_4_-3.jpg 0131 [40734,12991] CD3 0.000370 0.019870 0.944357 0.000098 0.010204 0.986466
16 0131_[39373,13898]_part_3_-5.jpg 0131 [39373,13898] CD8 0.007627 0.099724 0.564984 0.000950 0.035366 0.914037
17 0131_[46833,16191]_part_1_-2.jpg 0131 [46833,16191] DAPI 0.000174 0.013634 0.975392 0.000096 0.010134 0.986607
18 0131_[50072,7868]_part_1_-1.jpg 0131 [50072,7868] CK 0.002130 0.050434 0.812682 0.000505 0.024578 0.950177
19 0131_[39813,12991]_part_2_-4.jpg 0131 [39813,12991] FoxP3 0.002872 0.057647 0.679240 0.000440 0.022591 0.933212
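Beyond per-file rows, the table lends itself to a per-phenotype summary; a hypothetical follow-up using pandas groupby on a small hand-copied subset of the rows above:

```python
import pandas as pd

# a few rows copied from the test table, keeping only phenotype, MSE and SSIM
df = pd.DataFrame({
    'phenotype': ['CD3', 'CK', 'CK', 'DAPI', 'CD8'],
    'MSE':  [0.000588, 0.000653, 0.000190, 0.002778, 0.000335],
    'SSIM': [0.932013, 0.939602, 0.975974, 0.801640, 0.969690],
})
# average metrics per phenotype
summary = df.groupby('phenotype')[['MSE', 'SSIM']].mean()
print(summary)
```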
In [2]:
import numpy as np
mse = np.mean([0.002872, 0.002130,0.000174, 0.007627,  0.000370, 0.000251, 0.003640, 0.002870, 0.000331,
       0.000477, 0.000733, 0.001084, 0.000335, 0.000227,0.001919, 0.002778,  0.000335,0.000190,
       0.000653, 0.000588])
mse
Out[2]:
0.0014792
In [5]:
r1 = [1,'0131_[42912,16664]_part_2_-1.jpg','0131,[42912,16664]','CK',0.000653,0.028866,0.939602,0.000325,0.020380,0.972099]
r2 = [2,'0131_[42242,13869]_part_1_-1.jpg','0131,[42242,13869]','CK',0.000190,0.014383,0.975974,0.000128,0.011796,0.984911]
r3 = [3,'0131_[41114,18347]_part_2_-5.jpg','0131,[41114,18347]','CD8',0.000335,0.019394,0.969690,0.000224,0.015862,0.981349]
r4 = [4,'0131_[45930,8132]_part_4_-2.jpg','0131,[45930,8132]','DAPI',0.002778,0.063097,0.801640,0.000614,0.029738,0.955995]
r5 = [5,'0131_[49577,17995]_part_2_-5.jpg','0131,[49577,17995]','CD8',0.001919,0.050938,0.807614,0.000434,0.024246,0.956979]
r6 = [6,'0131_[40968,17689]_part_3_-0.jpg','0131,[40968,17689]','CD4',0.000227,0.015651,0.980994,0.000198,0.014615,0.983728]
r7 = [7,'0131_[47748,18478]_part_4_-5.jpg','0131,[47748,18478]','CD8',0.000335,0.019255,0.967275,0.000174,0.013884,0.984839]
r8 = [8,'0131_[43076,8951]_part_2_-3.jpg','0131,[43076,8951]','CD3',0.001084,0.035522,0.850774,0.000321,0.019350,0.957733]
r9 = [9,'0131_[50072,7868]_part_3_-5.jpg','0131,[50072,7868]','CD8',0.000733,0.028273,0.902026,0.000292,0.017852,0.964223]
r10 = [10,'0131_[44467,14410]_part_1_-4.jpg','0131,[44467,14410]','FoxP3',0.000477,0.023085,0.938726,0.000175,0.013973,0.978849]
import pandas as pd
df = pd.DataFrame(data = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10])
df
Out[5]:
0 1 2 3 4 5 6 7 8 9
0 1 0131_[42912,16664]_part_2_-1.jpg 0131,[42912,16664] CK 0.000653 0.028866 0.939602 0.000325 0.020380 0.972099
1 2 0131_[42242,13869]_part_1_-1.jpg 0131,[42242,13869] CK 0.000190 0.014383 0.975974 0.000128 0.011796 0.984911
2 3 0131_[41114,18347]_part_2_-5.jpg 0131,[41114,18347] CD8 0.000335 0.019394 0.969690 0.000224 0.015862 0.981349
3 4 0131_[45930,8132]_part_4_-2.jpg 0131,[45930,8132] DAPI 0.002778 0.063097 0.801640 0.000614 0.029738 0.955995
4 5 0131_[49577,17995]_part_2_-5.jpg 0131,[49577,17995] CD8 0.001919 0.050938 0.807614 0.000434 0.024246 0.956979
5 6 0131_[40968,17689]_part_3_-0.jpg 0131,[40968,17689] CD4 0.000227 0.015651 0.980994 0.000198 0.014615 0.983728
6 7 0131_[47748,18478]_part_4_-5.jpg 0131,[47748,18478] CD8 0.000335 0.019255 0.967275 0.000174 0.013884 0.984839
7 8 0131_[43076,8951]_part_2_-3.jpg 0131,[43076,8951] CD3 0.001084 0.035522 0.850774 0.000321 0.019350 0.957733
8 9 0131_[50072,7868]_part_3_-5.jpg 0131,[50072,7868] CD8 0.000733 0.028273 0.902026 0.000292 0.017852 0.964223
9 10 0131_[44467,14410]_part_1_-4.jpg 0131,[44467,14410] FoxP3 0.000477 0.023085 0.938726 0.000175 0.013973 0.978849
In [34]:
df.to_csv('data/metrics_sim5.csv')
In [33]:
av_metrics = pd.DataFrame(index = ['Average and median metrics'],
    data={
        'avg MSE': np.mean(df[4]),
        'med MSE': np.median(df[4]),
        'avg NMSE': np.mean(df[5]),
        'med NMSE': np.median(df[5]),
        'avg SSIM': np.mean(df[6]),
        'med SSIM': np.median(df[6]),
        'avg LR_MSE': np.mean(df[7]),
        'med LR_MSE': np.median(df[7]),
        'avg LR_NMSE': np.mean(df[8]),
        'med LR_NMSE': np.median(df[8]),
        'avg LR_SSIM': np.mean(df[9]),
        'med LR_SSIM': np.median(df[9])
    })
av_metrics.to_csv('data/av_metrics_sim5.csv')
av_metrics.transpose()
Out[33]:
Average and median metrics
avg MSE 0.000873
med MSE 0.000565
avg NMSE 0.029846
med NMSE 0.025679
avg SSIM 0.913432
med SSIM 0.939164
avg LR_MSE 0.000289
med LR_MSE 0.000258
avg LR_NMSE 0.018170
med LR_NMSE 0.016857
avg LR_SSIM 0.972071
med LR_SSIM 0.975474
In [31]:
# plot per-file MSE, NMSE and SSIM for predictions and bicubic interpolation, with medians:
fig = plt.figure(figsize = (20, 10))
ax = fig.add_subplot(1, 3, 1)
ax.plot(df[4],  '-x', label='MSE prediction vs GT')
ax.plot(df[7],'-x', label = 'MSE interpolated HR vs GT')
ax.axhline(np.median(df[7]),color = 'r')
ax.axhline(np.median(df[4]),color = 'r')
plt.xlabel('files')
plt.ylabel('MSE')
ax.legend()

ax = fig.add_subplot(1, 3, 2)
ax.plot(df[5],  '-x', label='NMSE prediction vs GT')
ax.plot(df[8],'-x', label = 'NMSE interpolated HR vs GT')
ax.axhline(np.median(df[8]),color = 'r')
ax.axhline(np.median(df[5]),color = 'r')
plt.xlabel('files')
plt.ylabel('NMSE')
ax.legend()

ax = fig.add_subplot(1, 3, 3)
ax.plot(df[6],  '-x', label='SSIM prediction vs GT')
ax.plot(df[9],'-x', label = 'SSIM interpolated HR vs GT')
ax.axhline(np.median(df[9]),color = 'r')
ax.axhline(np.median(df[6]),color = 'r')
plt.xlabel('files')
plt.ylabel('SSIM')
ax.legend()

plt.suptitle('Predicted and interpolated HR vs GT')
plt.savefig('images/intHRvsGT_sim5.png')